{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Watch a Smart Agent!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.Start the Environment for Trained Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import gym\n",
    "import os\n",
    "import time\n",
    "\n",
    "from agent import Agent\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "def rgb2gray(rgb, norm=True):\n",
    "        # rgb image -> gray [0, 1]\n",
    "    gray = np.dot(rgb[..., :], [0.299, 0.587, 0.114])\n",
    "    if norm:\n",
    "        # normalize\n",
    "        gray = gray / 128. - 1.\n",
    "    return gray\n",
    "\n",
    "seed = 0\n",
    "img_stack = 4\n",
    "action_repeat = 10\n",
    "env = gym.make('CarRacing-v0', verbose=0)\n",
    "state = env.reset()\n",
    "reward_threshold = env.spec.reward_threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Wrapper():\n",
    "    \"\"\"\n",
    "    Environment wrapper for CarRacing \n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env):\n",
    "        self.env = env  \n",
    "\n",
    "    def reset(self):\n",
    "        self.counter = 0\n",
    "        self.av_r = self.reward_memory()\n",
    "\n",
    "        self.die = False\n",
    "        img_rgb = env.reset()\n",
    "        img_gray = rgb2gray(img_rgb)\n",
    "        self.stack = [img_gray] * img_stack  # four frames for decision\n",
    "        return np.array(self.stack)\n",
    "\n",
    "    def step(self, action):\n",
    "        total_reward = 0\n",
    "        for i in range(action_repeat):\n",
    "            img_rgb, reward, die, _ = env.step(action)\n",
    "            # don't penalize \"die state\"\n",
    "            if die:\n",
    "                reward += 100\n",
    "            # green penalty\n",
    "            if np.mean(img_rgb[:, :, 1]) > 185.0:\n",
    "                reward -= 0.05\n",
    "            total_reward += reward\n",
    "            # if no reward recently, end the episode\n",
    "            done = True if self.av_r(reward) <= -0.1 else False\n",
    "            if done or die:\n",
    "                break\n",
    "        img_gray = rgb2gray(img_rgb)\n",
    "        self.stack.pop(0)\n",
    "        self.stack.append(img_gray)\n",
    "        assert len(self.stack) == img_stack\n",
    "        return np.array(self.stack), total_reward, done, die\n",
    "\n",
    "\n",
    "    @staticmethod\n",
    "    def reward_memory():\n",
    "        # record reward for last 100 steps\n",
    "        count = 0\n",
    "        length = 100\n",
    "        history = np.zeros(length)\n",
    "\n",
    "        def memory(reward):\n",
    "            nonlocal count\n",
    "            history[count] = reward\n",
    "            count = (count + 1) % length\n",
    "            return np.mean(history)\n",
    "\n",
    "        return memory\n",
    "    \n",
    "agent = Agent(device)\n",
    "\n",
    "env_wrap = Wrapper(env)    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Prepare Load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load(agent, directory, filename):\n",
    "    agent.net.load_state_dict(torch.load(os.path.join(directory,filename)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Prepare Player"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "import os\n",
    "\n",
    "def play(env, agent, n_episodes):\n",
    "    state = env_wrap.reset()\n",
    "    \n",
    "    scores_deque = deque(maxlen=100)\n",
    "    scores = []\n",
    "    \n",
    "    for i_episode in range(1, n_episodes+1):\n",
    "        state = env_wrap.reset()        \n",
    "        score = 0\n",
    "        \n",
    "        time_start = time.time()\n",
    "        \n",
    "        while True:\n",
    "            action, a_logp = agent.select_action(state)\n",
    "            env.render()\n",
    "            next_state, reward, done, die = env_wrap.step( \\\n",
    "                action * np.array([2., 1., 1.]) + np.array([-1., 0., 0.]))\n",
    "\n",
    "            state = next_state\n",
    "            score += reward\n",
    "            \n",
    "            if done or die:\n",
    "                break \n",
    "\n",
    "        s = (int)(time.time() - time_start)\n",
    "        \n",
    "        scores_deque.append(score)\n",
    "        scores.append(score)\n",
    "\n",
    "        print('Episode {}\\tAverage Score: {:.2f},\\tScore: {:.2f} \\tTime: {:02}:{:02}:{:02}'\\\n",
    "                  .format(i_episode, np.mean(scores_deque), score, s//3600, s%3600//60, s%60))  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Load and Play: Score = 350-550"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 1\tAverage Score: 63.53,\tScore: 63.53 \tTime: 00:00:04\n",
      "Episode 2\tAverage Score: 305.90,\tScore: 548.28 \tTime: 00:00:10\n",
      "Episode 3\tAverage Score: 370.60,\tScore: 500.00 \tTime: 00:00:11\n",
      "Episode 4\tAverage Score: 366.48,\tScore: 354.09 \tTime: 00:00:07\n",
      "Episode 5\tAverage Score: 304.39,\tScore: 56.03 \tTime: 00:00:05\n"
     ]
    }
   ],
   "source": [
    "load(agent, 'dir_chk', 'model_weights_350-550.pth')\n",
    "play(env, agent, n_episodes=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Load and Play: Score = 580-660"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 1\tAverage Score: 603.72,\tScore: 603.72 \tTime: 00:00:12\n",
      "Episode 2\tAverage Score: 593.94,\tScore: 584.16 \tTime: 00:00:11\n",
      "Episode 3\tAverage Score: 432.31,\tScore: 109.06 \tTime: 00:00:08\n",
      "Episode 4\tAverage Score: 480.99,\tScore: 627.01 \tTime: 00:00:11\n",
      "Episode 5\tAverage Score: 517.67,\tScore: 664.38 \tTime: 00:00:11\n"
     ]
    }
   ],
   "source": [
    "load(agent, 'dir_chk', 'model_weights_480-660.pth')\n",
    "play(env, agent, n_episodes=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Load and Play: Score = 820-980"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 1\tAverage Score: 1003.80,\tScore: 1003.80 \tTime: 00:00:10\n",
      "Episode 2\tAverage Score: 958.42,\tScore: 913.04 \tTime: 00:00:11\n",
      "Episode 3\tAverage Score: 943.30,\tScore: 913.04 \tTime: 00:00:11\n",
      "Episode 4\tAverage Score: 943.02,\tScore: 942.18 \tTime: 00:00:11\n",
      "Episode 5\tAverage Score: 938.26,\tScore: 919.25 \tTime: 00:00:11\n"
     ]
    }
   ],
   "source": [
    "load(agent, 'dir_chk', 'model_weights_820-980.pth')\n",
    "play(env, agent, n_episodes=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
